# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: DongYang Li <lidongyang2001@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from collections import namedtuple
import tempfile
import pickle
import itertools
from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions)
from jittor.einops import EinopsError
import jittor as jt
import numpy
import unittest

# tests/__init__.py
import os
from jittor.einops import _backends
import warnings

# Textual on/off flag value -> boolean: the empty string and '0' mean off, '1' means on.
# NOTE(review): not referenced elsewhere in this file; kept from the merged tests/__init__.py.
flag_to_bool = dict([
    ('', False),
    ('0', False),
    ('1', True),
])


def collect_test_backends(symbolic=False, layers=False):
    """
    Instantiate the backends that tests should run against.

    :param symbolic: select symbolic frameworks when True, imperative ones otherwise
        (no symbolic backend is registered here, so True always yields an empty list)
    :param layers: select backends for layer tests when True, for operation tests otherwise
    :return: list of backend instances that could be constructed
    """
    if symbolic:
        candidates = []
    elif layers:
        candidates = [_backends.JittorBackend]
    else:
        candidates = [_backends.NumpyBackend, _backends.JittorBackend]

    instances = []
    for candidate in candidates:
        try:
            instance = candidate()
        except ImportError:
            # A broken installation of one backend should not abort collection:
            # tests that need it will fail, the rest simply run without it.
            warnings.warn('backend could not be initialized for tests: {}'.format(candidate))
        else:
            instances.append(instance)
    return instances


# test/test_ops.py

# Imperative backends (numpy and, if it can be constructed, jittor) used by the
# operation-level tests. Built once at import time; a backend whose constructor
# raises ImportError is skipped with a warning by collect_test_backends.
imp_op_backends = collect_test_backends(symbolic=False, layers=False)

# test/test_layer.py


class TestSlice(unittest.TestCase):
    """Tests for jittor.einops ``rearrange`` / ``reduce`` and the einops layers.

    Operations are run on numpy and on every available imperative backend, and
    backend results are compared against numpy references. The data tables used
    here (``repeat_test_cases``, ``rearrangement_patterns``, ...) are module-level
    names defined after this class.
    """

    def test_anonymous_axes(self):
        """Repeat patterns with anonymous (numeric) axes are undone by min/max reductions."""
        x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
        for pattern, axis_dimensions in test_cases_repeat_anonymous:
            check_reversion(x, pattern, **axis_dimensions)

    def test_repeat_imperatives(self):
        """reduction='repeat' gives the same result on every backend as on numpy."""
        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
        for backend in imp_op_backends:
            print('Repeat tests for ', backend.framework_name)

            for pattern, axis_dimensions in repeat_test_cases:
                expected = reduce(x, pattern, reduction='repeat', **axis_dimensions)
                converted = backend.from_numpy(x)
                repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions)
                result = backend.to_numpy(repeated)
                assert numpy.array_equal(result, expected)

    def test_repeat_numpy(self):
        """Repeat on numpy: adding a size-1 axis, and reversibility via min/max."""
        # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well
        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
        x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1)
        assert numpy.array_equal(x[None], x1)
        for pattern, axis_dimensions in repeat_test_cases:
            check_reversion(x, pattern, **axis_dimensions)

    def test_tiling_imperatives(self):
        """backend.tile matches numpy.tile for several repeat vectors."""
        for backend in imp_op_backends:
            print('Tiling tests for ', backend.framework_name)
            x_input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
            test_cases = [
                (1, 1, 1, 1, 1),
                (1, 2, 1, 3, 1),
                (3, 1, 1, 4, 1),
            ]
            for repeats in test_cases:
                expected = numpy.tile(x_input, repeats)
                converted = backend.from_numpy(x_input)
                repeated = backend.tile(converted, repeats)
                result = backend.to_numpy(repeated)
                assert numpy.array_equal(result, expected)

    def test_gradients_imperatives(self):
        """Gradients through chained reductions agree across differentiable backends."""
        # lazy - just checking reductions
        for reduction in _reductions:
            x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32')
            results = {}
            for backend in imp_op_backends:
                y0 = backend.from_numpy(x)
                # backends without autograd (e.g. plain numpy arrays have no .grad) are skipped
                if 'jittor' not in backend.framework_name and not hasattr(y0, 'grad'):
                    continue
                y1 = reduce(y0, 'a b c -> c a', reduction=reduction)
                y2 = reduce(y1, 'c a -> a c', reduction=reduction)
                y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2)
                y4 = reduce(y3, '... -> ', reduction=reduction)
                if 'jittor' in backend.framework_name:
                    grad = backend.jittor.grad(y4, y0)
                else:
                    y4.backward()
                    grad = y0.grad
                results[backend.framework_name] = backend.to_numpy(grad)

            print('comparing gradients for', results.keys())
            for name1, grad1 in results.items():
                for name2, grad2 in results.items():
                    assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients']

    def test_concatenations_and_stacking(self):
        """A list of arrays is stacked along a new leading axis, as numpy.asarray does."""
        for backend in imp_op_backends:
            print('testing shapes for ', backend.framework_name)
            for n_arrays in [1, 2, 5]:
                shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
                for shape in shapes:
                    if backend.framework_name == 'jittor' and len(shape) == 0:
                        # jittor stores scalar in 1d array
                        continue
                    arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)]
                    arrays2 = [backend.from_numpy(array) for array in arrays1]
                    result0 = numpy.asarray(arrays1)
                    result1 = rearrange(arrays1, '...->...')
                    result2 = rearrange(arrays2, '...->...')
                    assert numpy.array_equal(result0, result1)
                    assert numpy.array_equal(result1, backend.to_numpy(result2))

                    result1 = rearrange(arrays1, 'b ... -> ... b')
                    result2 = rearrange(arrays2, 'b ... -> ... b')
                    assert numpy.array_equal(result1, backend.to_numpy(result2))

    def test_enumerating_directions(self):
        """_enumerate_directions agrees between numpy and each backend."""
        for backend in imp_op_backends:
            print('testing directions for', backend.framework_name)
            for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
                if backend.framework_name == 'jittor' and len(shape) == 0:
                    # jittor stores scalar in 1d array
                    continue
                x = numpy.arange(numpy.prod(shape)).reshape(shape)
                axes1 = _enumerate_directions(x)
                axes2 = _enumerate_directions(backend.from_numpy(x))
                assert len(axes1) == len(axes2) == len(shape)
                for ax1, ax2 in zip(axes1, axes2):
                    ax2 = backend.to_numpy(ax2)
                    assert ax1.shape == ax2.shape
                    assert numpy.allclose(ax1, ax2)

    def test_reduction_with_callable_imperatives(self):
        """A user-supplied reduction callable (logsumexp) matches a naive numpy version."""
        x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32')
        x_numpy /= x_numpy.max()

        def logsumexp_jittor(x, tuple_of_axes):
            return jt.nn.logsumexp(x, tuple_of_axes)

        def logsumexp_numpy(x, tuple_of_axes):
            # very naive logsumexp to compare to
            minused = x.max(tuple_of_axes)
            y = x - x.max(tuple_of_axes, keepdims=True)
            y = numpy.exp(y)
            y = numpy.sum(y, axis=tuple_of_axes)
            return numpy.log(y) + minused

        backend2callback = {
            _backends.JittorBackend.framework_name: logsumexp_jittor,
            _backends.NumpyBackend.framework_name: logsumexp_numpy,
        }

        for backend in imp_op_backends:
            if backend.framework_name not in backend2callback:
                continue

            backend_callback = backend2callback[backend.framework_name]

            x_backend = backend.from_numpy(x_numpy)
            for pattern1, pattern2 in equivalent_reduction_patterns:
                print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2)
                output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy)
                output_backend = reduce(x_backend, pattern1, reduction=backend_callback)
                assert numpy.allclose(
                    output_numpy,
                    backend.to_numpy(output_backend),
                )

    def test_reduction_stress_imperatives(self):
        """Random permute-and-reduce patterns match numpy transpose + reduction."""
        for backend in imp_op_backends:
            print('Stress-testing reduction for ', backend.framework_name)
            for reduction in _reductions + ('rearrange',):
                dtype = 'int64'
                coincide = numpy.array_equal
                if reduction in ['mean', 'prod']:
                    dtype = 'float64'
                    coincide = numpy.allclose
                for n_axes in range(11):
                    shape = numpy.random.randint(2, 4, size=n_axes)
                    permutation = numpy.random.permutation(n_axes)
                    # the first `skipped` axes of the permutation are reduced away
                    skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1)
                    left = ' '.join('x' + str(i) for i in range(n_axes))
                    right = ' '.join('x' + str(i) for i in permutation[skipped:])
                    pattern = left + '->' + right
                    x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape)
                    if reduction == 'prod':
                        x /= x.mean()  # to avoid overflows
                    result1 = reduce(x, pattern, reduction=reduction)
                    result2 = x.transpose(permutation)
                    if skipped > 0:
                        result2 = getattr(result2, reduction)(axis=tuple(range(skipped)))
                    assert coincide(result1, result2)
                    check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False)

    def test_reduction_imperatives(self):
        """Fixed reduce patterns on each backend match hand-computed numpy results."""
        for backend in imp_op_backends:
            print('Reduction tests for ', backend.framework_name)
            for reduction in _reductions:
                # slight redundancy for simpler order - numpy version is evaluated multiple times
                x_input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6])
                if reduction in ['mean', 'prod']:
                    x_input = x_input / x_input.astype('float64').mean()
                test_cases = [
                    ['a b c d e -> ', {},
                     getattr(x_input, reduction)()],
                    ['a ... -> ', {},
                     getattr(x_input, reduction)()],
                    ['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2),
                     getattr(x_input, reduction)()],
                    ['a b c d e -> (e c) a', {},
                     getattr(x_input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a ... c d e -> (e c) a', {},
                     getattr(x_input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a b c d e ... -> (e c) a', {},
                     getattr(x_input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a b c d e -> (e c a)', {},
                     getattr(x_input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
                    ['(a a2) ... -> (a2 a) ...', dict(a2=1),
                     x_input],
                ]
                for pattern, axes_lengths, expected_result in test_cases:
                    result = reduce(backend.from_numpy(x_input.copy()), pattern, reduction=reduction, **axes_lengths)
                    result = backend.to_numpy(result)
                    assert numpy.allclose(result, expected_result)

    def test_rearrange_permutations_numpy(self):
        """Random axis permutations agree with two independent numpy computations."""
        # tests random permutation of axes against two independent numpy ways
        for n_axes in range(1, 10):
            x_input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
            permutation = numpy.random.permutation(n_axes)
            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes))
            right_expression = ' '.join('i' + str(axis) for axis in permutation)
            expression = left_expression + ' -> ' + right_expression
            result = rearrange(x_input, expression)

            for pick in numpy.random.randint(0, 2, [10, n_axes]):
                assert x_input[tuple(pick)] == result[tuple(pick[permutation])]

        for n_axes in range(1, 10):
            x_input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
            permutation = numpy.random.permutation(n_axes)
            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1])
            right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1])
            expression = left_expression + ' -> ' + right_expression
            result = rearrange(x_input, expression)
            assert result.shape == x_input.shape
            # each element's value encodes its index bits; permuting axes permutes the bits
            expected_result = numpy.zeros_like(x_input)
            for original_axis, result_axis in enumerate(permutation):
                expected_result |= ((x_input >> original_axis) & 1) << result_axis

            assert numpy.array_equal(result, expected_result)

    def test_rearrange_consistency_numpy(self):
        """rearrange on numpy keeps elements and dtype and round-trips through compositions."""
        shape = [1, 2, 3, 5, 7, 11]
        x = numpy.arange(numpy.prod(shape)).reshape(shape)
        for pattern in [
            'a b c d e f -> a b c d e f',
            'b a c d e f -> a b d e f c',
            'a b c d e f -> f e d c b a',
            'a b c d e f -> (f e) d (c b a)',
            'a b c d e f -> (f e d c b a)',
        ]:
            result = rearrange(x, pattern)
            assert len(numpy.setdiff1d(x, result)) == 0
            assert result.dtype == x.dtype

        result = rearrange(x, 'a b c d e f -> a (b) (c d e) f')
        assert numpy.array_equal(x.flatten(), result.flatten())

        result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11')
        assert numpy.array_equal(x, result)

        result1 = rearrange(x, 'a b c d e f -> f e d c b a')
        result2 = rearrange(x, 'f e d c b a -> a b c d e f')
        assert numpy.array_equal(result1, result2)

        result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5)
        assert numpy.array_equal(x, result)

        sizes = dict(zip('abcdef', shape))
        temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes)
        result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes)
        assert numpy.array_equal(x, result)

        x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4])
        result = rearrange(x2, 'a b c -> b c a')
        assert x2[1, 2, 3] == result[2, 3, 1]
        assert x2[0, 1, 2] == result[1, 2, 0]

    def test_ellipsis_ops_imperative(self):
        """ Checking various patterns against numpy """
        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for is_symbolic in [True, False]:
            for backend in collect_test_backends(symbolic=is_symbolic, layers=False):
                for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
                    check_op_against_numpy(backend, x, pattern, axes_lengths={},
                                           reduction='rearrange', is_symbolic=is_symbolic)

                for reduction in ['min', 'max', 'sum']:
                    for pattern in itertools.chain(*equivalent_reduction_patterns):
                        check_op_against_numpy(backend, x, pattern, axes_lengths={},
                                               reduction=reduction, is_symbolic=is_symbolic)

    def test_ellipsis_ops_numpy(self):
        """Identity patterns are no-ops and equivalent pattern pairs agree on numpy."""
        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for pattern in identity_patterns:
            assert numpy.array_equal(x, rearrange(x, pattern)), pattern

        for pattern1, pattern2 in equivalent_rearrange_patterns:
            assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2))

        for reduction in ['min', 'max', 'sum']:
            for pattern1, pattern2 in equivalent_reduction_patterns:
                assert numpy.array_equal(reduce(x, pattern1, reduction=reduction),
                                         reduce(x, pattern2, reduction=reduction))

    def test_collapsed_ellipsis_errors_out(self):
        """A parenthesized ellipsis is accepted on the right side but rejected on the left."""
        x = numpy.zeros([1, 1, 1, 1, 1])
        rearrange(x, 'a b c d ... ->  a b c ... d')
        with self.assertRaises(EinopsError):
            rearrange(x, 'a b c d (...) ->  a b c ... d')

        rearrange(x, '... ->  (...)')
        with self.assertRaises(EinopsError):
            rearrange(x, '(...) -> (...)')

    def test_rearrange_imperative(self):
        """Rearrange layer: rejects wrong shapes, survives pickling, matches numpy, has unit gradients."""
        for backend in collect_test_backends(symbolic=False, layers=True):
            print('Test layer for ', backend.framework_name)

            for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
                x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                result_numpy = rearrange(x, pattern, **axes_lengths)
                layer = backend.layers().Rearrange(pattern, **axes_lengths)
                for shape in wrong_shapes:
                    try:
                        layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
                    except Exception:
                        pass
                    else:
                        raise AssertionError('Failure expected')

                # simple pickling / unpickling
                layer2 = pickle.loads(pickle.dumps(layer))
                result1 = backend.to_numpy(layer(backend.from_numpy(x)))
                result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
                assert numpy.allclose(result_numpy, result1)
                assert numpy.allclose(result1, result2)

                just_sum = backend.layers().Reduce('...->', reduction='sum')

                variable = backend.from_numpy(x)
                result = just_sum(layer(variable))

                if 'jittor' in backend.framework_name:
                    grad = backend.jittor.grad(result, variable)
                else:
                    result.backward()
                    grad = variable.grad

                # rearrange only moves elements, so d(sum)/dx is 1 everywhere
                assert numpy.allclose(backend.to_numpy(grad), 1)

    def test_reduce_imperative(self):
        """Reduce layer: rejects wrong shapes, survives pickling, matches numpy, sane gradients."""
        for backend in collect_test_backends(symbolic=False, layers=True):
            print('Test layer for ', backend.framework_name)
            for reduction in _reductions:
                for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
                    print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
                    x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                    x /= x.mean()
                    result_numpy = reduce(x, pattern, reduction, **axes_lengths)
                    layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
                    for shape in wrong_shapes:
                        try:
                            layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
                        except Exception:
                            pass
                        else:
                            raise AssertionError('Failure expected')

                    # simple pickling / unpickling
                    layer2 = pickle.loads(pickle.dumps(layer))
                    result1 = backend.to_numpy(layer(backend.from_numpy(x)))
                    result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
                    assert numpy.allclose(result_numpy, result1)
                    assert numpy.allclose(result1, result2)

                    just_sum = backend.layers().Reduce('...->', reduction='sum')

                    variable = backend.from_numpy(x)
                    result = just_sum(layer(variable))

                    if 'jittor' in backend.framework_name:
                        grad = backend.jittor.grad(result, variable)
                        grad = backend.to_numpy(grad)
                    else:
                        result.backward()
                        grad = backend.to_numpy(variable.grad)
                    if reduction == 'sum':
                        assert numpy.allclose(grad, 1)
                    if reduction == 'mean':
                        # every element contributes with the same weight
                        assert numpy.allclose(grad, grad.min())
                    if reduction in ['max', 'min']:
                        # gradient is routed only to selected elements
                        assert numpy.all(numpy.isin(grad, [0, 1]))
                        assert numpy.sum(grad) > 0.5

    def test_jittor_layer(self):
        """Reduce-based and MaxPool-based models agree once their weights are shared."""
        has_jittor = any(backend.framework_name == 'jittor' for backend in collect_test_backends(symbolic=False, layers=True))
        if has_jittor:
            # checked that jittor present
            import jittor

            rtol = 1e-05
            atol = 1e-08

            def allclose(a, b):
                # elementwise |a - b| <= atol + rtol * |b|, reduced with all()
                return jittor.all(jittor.abs(a - b) <= atol + rtol * jittor.abs(b))

            model1 = create_jittor_model(use_reduce=True)
            model2 = create_jittor_model(use_reduce=False)
            x = jittor.randn([10, 3, 32, 32])
            # random models have different predictions
            assert not allclose(model1(x), model2(x))
            model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
            assert allclose(model1(x), model2(x))


# A layer test case: einops pattern, extra axis lengths passed to the layer,
# a valid input shape, and shapes the layer is expected to reject.
testcase = namedtuple('testcase', 'pattern axes_lengths input_shape wrong_shapes')

# Test cases for the Rearrange layer (also reused by the Reduce layer tests).
# For each case, the layer built from `pattern` and `axes_lengths` must accept an
# input of `input_shape` and raise on every shape listed in `wrong_shapes`.
rearrangement_patterns = [
    testcase('b c h w -> b (c h w)', dict(c=20), (10, 20, 30, 40),
             [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]]),
    # NOTE(review): `()` appears twice in wrong_shapes below; the duplicate is redundant.
    testcase('b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1', dict(h2=2, w2=2), (10, 20, 30, 40),
             [(), (1, 1, 1, 1), (1, 10, 3), ()]),
    testcase('b ... c -> c b ...', dict(b=10), (10, 20, 30),
             [(), (10,), (5, 10)]),
]

# Test cases for the Reduce layer: every rearrangement case (pure permutations are
# valid reductions too) plus cases that actually drop or pool axes.
reduction_patterns = rearrangement_patterns + [
    testcase('b c h w -> b ()', dict(b=10), (10, 20, 30, 40),
             [(10,), (10, 20, 30)]),
    testcase('b c (h1 h2) (w1 w2) -> b c h1 w1', dict(h1=15, h2=2, w2=2), (10, 20, 30, 40),
             [(10, 20, 31, 40)]),
    testcase('b ... c -> b', dict(b=10), (10, 20, 30, 40),
             [(10,), (11, 10)]),
]

# Pairs of reduce patterns that must give identical results on a 5-d input.
# The second form uses an ellipsis; the irregular whitespace is intentional and
# exercises the pattern parser.
equivalent_reduction_patterns = [
    ('a b c d e -> ', ' ... ->  '),
    ('a b c d e -> (e a)', 'a ... e -> (e a)'),
    ('a b c d e -> d (a e)', ' a b c d e ... -> d (a e) '),
    ('a b c d e -> (a b)', ' ... c d e  -> (...) '),
]

# Pairs of rearrange patterns that must give identical results on a 5-d input:
# explicit axis names vs. the equivalent ellipsis form.
equivalent_rearrange_patterns = [
    ('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '),
    ('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'),
    ('a b c d e -> a b c d e', '... -> ... '),
    ('a b c d e -> (a b c d e)', '... ->  (...)'),
    ('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'),
    ('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'),
]

# Rearrange patterns that must return a 5-d input unchanged. They cover an
# ellipsis that matches no axes and an ellipsis wrapping a single axis; spacing
# around '->' varies on purpose.
identity_patterns = [
    '...->...',
    'a b c d e-> a b c d e',
    'a b c d e ...-> ... a b c d e',
    'a b c d e ...-> a ... b c d e',
    '... a b c d e -> ... a b c d e',
    'a ... e-> a ... e',
    'a ... -> a ... ',
    'a ... c d e -> a (...) c d e',
]

# (repeat pattern, axis sizes) pairs using anonymous numeric axes ('1', '2', '3', '()').
# Each is checked with check_reversion, i.e. it must be undone by min/max reduction.
test_cases_repeat_anonymous = [
    # all assume that input has shape [1, 2, 4, 6]
    ('a b c d -> c a d b', dict()),
    ('a b c d -> (c 2 d a b)', dict(a=1, c=4, d=6)),
    ('1 b c d -> (d copy 1) 3 b c ', dict(copy=3)),
    ('1 ...  -> 3 ... ', dict()),
    ('() ... d -> 1 (copy1 d copy2) ... ', dict(copy1=2, copy2=3)),
    ('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', dict()),

]

# (repeat pattern, axis sizes) pairs with named new axes. They are compared across
# backends and checked for reversibility with min/max reductions.
repeat_test_cases = [
    # all assume that input has shape [2, 3, 5]
    ('a b c -> c a b', dict()),
    ('a b c -> (c copy a b)', dict(copy=2, a=2, b=3, c=5)),
    ('a b c -> (a copy) b c ', dict(copy=1)),
    ('a b c -> (c a) (copy1 b copy2)', dict(a=2, copy1=1, copy2=2)),
    ('a ...  -> a ... copy', dict(copy=4)),
    ('... c -> ... (copy1 c copy2)', dict(copy1=1, copy2=2)),
    ('...  -> ... ', dict()),
    (' ...  -> copy1 ... copy2 ', dict(copy1=2, copy2=3)),
    ('a b c  -> copy1 a copy2 b c () ', dict(copy1=2, copy2=1)),
]


def check_reversion(x, repeat_pattern, **sizes):
    """Verify a repeat pattern by inverting it with min and max reductions.

    The pattern is applied with reduction='repeat'. Swapping its two sides then
    reduces the copies away; both the min- and the max-reduced results must
    reproduce ``x`` exactly.
    """
    lhs, rhs = repeat_pattern.split('->')
    inverse_pattern = '->'.join([rhs, lhs])
    tiled = reduce(x, repeat_pattern, reduction='repeat', **sizes)
    restored = [reduce(tiled, inverse_pattern, reduction=name, **sizes) for name in ('min', 'max')]
    for candidate in restored:
        assert numpy.array_equal(x, candidate)


def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False):
    """
    Helper to test result of operation (rearrange or reduce) against numpy.

    If reduction == 'rearrange', the rearrange op is tested, otherwise reduce with
    the given reduction. The operation runs on ``numpy_input`` directly (reference)
    and on the same data via ``backend`` (through a symbol when ``is_symbolic``).
    The two results must agree.

    :raises AssertionError: if the backend result differs from the numpy reference
    """
    if len(numpy_input.shape) == 0:
        # 0-d inputs are not checked
        return

    def operation(x):
        if reduction == 'rearrange':
            return rearrange(x, pattern, **axes_lengths)
        else:
            return reduce(x, pattern, reduction, **axes_lengths)

    numpy_result = operation(numpy_input)
    check_equal = numpy.array_equal
    p_none_dimension = 0.5
    if 'jittor' in backend.framework_name:
        check_equal = numpy.allclose
        p_none_dimension = 0
    if is_symbolic:
        # randomly mark some dimensions as unknown (None) in the symbolic shape
        symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape]
        symbol = backend.create_symbol(shape=symbol_shape)
        result_symbol = operation(symbol)
        backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)])
    else:
        backend_result = operation(backend.from_numpy(numpy_input))
        backend_result = backend.to_numpy(backend_result)

    # The comparison result must be asserted; previously it was computed and discarded,
    # so mismatches between numpy and the backend were never reported.
    assert check_equal(numpy_result, backend_result), \
        'backend {} disagrees with numpy for pattern {!r} (reduction={})'.format(
            backend.framework_name, pattern, reduction)


def create_jittor_model(use_reduce=False):
    """Build a small conv classifier for 3x32x32 inputs using einops layers.

    :param use_reduce: if True, the first 2x2 max-pooling is an einops ``Reduce``
        layer; otherwise it is ``MaxPool2d(kernel_size=2)``. Both variants have the
        same parameters, so weights can be exchanged via ``state_dict``.
    :return: a ``jittor.nn.Sequential`` model
    """
    from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
    from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix

    pool_pattern = 'b c (h h2) (w w2) -> b c h w'
    # Layers are listed (and therefore constructed) in the same order as the
    # model's forward pass.
    layers = [
        Conv2d(3, 6, kernel_size=(5, 5)),
        Reduce(pool_pattern, 'max', h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2),
        Conv2d(6, 16, kernel_size=(5, 5)),
        Reduce(pool_pattern, 'max', h2=2, w2=2),
        # 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5
        Rearrange('b c h w -> b (c h w)'),
        Linear(16 * 5 * 5, 120),
        ReLU(),
        Linear(120, 84),
        ReLU(),
        EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84),
        EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84),
        Linear(84, 10),
    ]
    return Sequential(*layers)


# Allow running this test module directly with the unittest runner.
if __name__ == '__main__':
    unittest.main()